function [hat_beta, SE, reject, reject0, pval, pval0, info, info0] = implement_jointtest_app_cluster(dy, z, shares, shifters, shifters_est, Lambda_gamma, Vgamma, Jgamma, Joutcome, ind_shifters_est, critical , ols, cluster_vec, cluster_est_vec)
%IMPLEMENT_JOINTTEST_APP_CLUSTER Joint estimation and test, allowing clustered shifters.
%
%   Implements the joint procedure for estimation and test used in the
%   empirical application (allowing for clusters of shifters).
%   Corresponding author: Rodrigo Adao
%   Date: 09/11/2024
%
%   Inputs (dimensions are the ones implied by how each argument is used):
%     dy               Nobs x 1 outcome vector.
%     z                Nobs x 1 instrument.
%     shares           Nobs x Nsh matrix; R = residual' * shares / Nobs.
%     shifters         Nsh x 1 column vector of shifters.
%     shifters_est     Nsh_est x NTm matrix of shifters entering the
%                      covariance with the estimated parameters.
%     Lambda_gamma     Nsh_est x NTm x Ntheta array combined with the beta
%                      scores to form the covariance term (presumably the
%                      per-shifter scores of the estimated parameters --
%                      confirm against the caller).
%     Vgamma           Ntheta x Ntheta matrix used in D*Vgamma*D'.
%     Jgamma           Ntheta x Ntheta matrix; D = -Jbeta / Jgamma.
%     Joutcome         Nobs x Ntheta matrix; Jbeta = (Joutcome'*z/Nobs)'.
%     ind_shifters_est Index into the Nsh entries of R selecting the
%                      shifters that correspond to rows of shifters_est.
%     critical         Threshold compared with |t| for the reject flags.
%     ols              1 -> hat_beta is normalized by sum(z.^2)/Nobs (OLS
%                      version); otherwise hat_beta is (z'*dy)/Nobs.
%     cluster_vec      Cluster id of each entry of shifters. Empty means
%                      every shifter is treated as independent.
%     cluster_est_vec  Cluster id of each row of shifters_est. Only used
%                      when cluster_vec is non-empty.
%
%   Outputs:
%     hat_beta         Point estimate.
%     SE               Standard error built from residual dy - slope*z.
%     reject, pval     |hat_beta/SE| > critical, and two-sided normal p-value.
%     reject0, pval0   Same, with the standard error built from the
%                      residual hat_eta0 = dy.
%     info, info0      Vgamma_beta' * inv(Vgamma) * Vgamma_beta / Vbeta for
%                      the two residual choices.
%
%   Errors:
%     implement_jointtest_app_cluster:clusterSize when a non-empty
%     cluster_vec does not have one id per shifter, or cluster_est_vec does
%     not have one id per row of shifters_est.

Nobs    = size(Joutcome, 1);
Ntheta  = size(Joutcome, 2);
Nsh_est = size(shifters_est, 1);

z_var     = sum(z.^2)/Nobs;       % second moment of the instrument
clustered = ~isempty(cluster_vec);

%Statistic estimation
cov_z_dy = (z'*dy)/Nobs;
if ols == 1
    %OLS version -- normalize covariance by variance of IV
    hat_beta = cov_z_dy/z_var;
    hat_eta  = dy - hat_beta*z;
else
    %Covariance version
    hat_beta = cov_z_dy;
    hat_eta  = dy - (hat_beta/z_var)*z;
end
hat_eta0 = dy;

%Statistics from estimated parameter
Jbeta = (Joutcome'*z/Nobs)';
D = -Jbeta/Jgamma;                % solve instead of forming inv(Jgamma)

%Compute variance term
R0 = hat_eta0'*shares/Nobs;
R  = hat_eta'*shares/Nobs;

if ~clustered
    %Shifters are iid
    Vbeta0 = (R0.^2)*(shifters.^2);
    Vbeta  = (R.^2)*(shifters.^2);
else
    %Shifters are iid across clusters
    if numel(cluster_vec) ~= numel(shifters)
        error('implement_jointtest_app_cluster:clusterSize', ...
              'cluster_vec must have one cluster id per shifter.');
    end
    ids    = cluster_vec(:);
    groups = unique(ids);
    score0 = R0(:).*shifters(:);
    score  = R(:).*shifters(:);

    Vbeta0 = 0;
    Vbeta  = 0;
    for c = 1:numel(groups)
        in_c = (ids == groups(c));
        % sum of all entries of L*L' equals (sum of L)^2, so the outer
        % product never has to be built.
        Vbeta0 = Vbeta0 + sum(score0(in_c))^2;
        Vbeta  = Vbeta  + sum(score(in_c))^2;
    end
end

%Covariance term
% Entry (k,t) of the beta score is R(k)*shifters_est(k,t); only its row sums
% are needed below because sum(a'*b,'all') == sum(a)*sum(b).
shifter_rowsum = sum(shifters_est, 2);
beta0_rowsum = reshape(R0(1, ind_shifters_est), [], 1).*shifter_rowsum;
beta_rowsum  = reshape(R(1, ind_shifters_est),  [], 1).*shifter_rowsum;
gamma_rowsum = reshape(sum(Lambda_gamma(1:Nsh_est, :, 1:Ntheta), 2), Nsh_est, Ntheta);

if ~clustered
    %Shifters are iid: each row of shifters_est is its own group
    Vgamma_beta0 = gamma_rowsum'*beta0_rowsum;
    Vgamma_beta  = gamma_rowsum'*beta_rowsum;
else
    %Shifters are iid across clusters
    if numel(cluster_est_vec) ~= Nsh_est
        error('implement_jointtest_app_cluster:clusterSize', ...
              'cluster_est_vec must have one cluster id per row of shifters_est.');
    end
    est_ids    = cluster_est_vec(:);
    est_groups = unique(est_ids);

    Vgamma_beta0 = zeros(Ntheta, 1);
    Vgamma_beta  = zeros(Ntheta, 1);
    for c = 1:numel(est_groups)
        in_c = (est_ids == est_groups(c));
        % Full within-cluster cross product: (sum of beta scores) times
        % (sum of gamma scores). This matches the clustered variance term
        % and reduces to the iid formula when every cluster is one shifter.
        gamma_c = sum(gamma_rowsum(in_c, :), 1)';
        Vgamma_beta0 = Vgamma_beta0 + gamma_c*sum(beta0_rowsum(in_c));
        Vgamma_beta  = Vgamma_beta  + gamma_c*sum(beta_rowsum(in_c));
    end
end

%Compute estimator variance, SE, and test
    %variance
    param_var  = D*Vgamma*(D');
    hatV0_beta = Vbeta0 + param_var + ( D*Vgamma_beta0 + (D*Vgamma_beta0)' );
    hatV_beta  = Vbeta  + param_var + ( D*Vgamma_beta  + (D*Vgamma_beta)'  );

    %Standard error
    SE  = hatV_beta.^(1/2);
    SE0 = hatV0_beta.^(1/2);
    if ols == 1
        SE  = SE/z_var;
        SE0 = SE0/z_var;
    end

    %rejection decision and p-value
    tstat  = hat_beta/SE;
    tstat0 = hat_beta/SE0;

    reject  = abs(tstat)  > critical;
    reject0 = abs(tstat0) > critical;

    % 2*(1 - Phi(|t|)) written as erfc(|t|/sqrt(2)): same value, no
    % toolbox dependency, and no cancellation for large |t|.
    pval  = erfc(abs(tstat)/sqrt(2));
    pval0 = erfc(abs(tstat0)/sqrt(2));

%Informativeness statistics
    info0 = ( (Vgamma_beta0')*(Vgamma\Vgamma_beta0) )/Vbeta0;
    info  = ( (Vgamma_beta' )*(Vgamma\Vgamma_beta ) )/Vbeta;

end